#!/usr/bin/env python3
"""Updates build/compile_commands.json for clangd to index C++ files.

USAGE:
  scripts/update_compile_commands.py
"""

import dataclasses
import json
import os
import subprocess
import sys
from typing import Any, Optional

# The implementation is largely borrowed from OpenXLA, with some usability improvements
# and fixes specific to PyTorch/XLA.

# Approximates parsed JSON. Values are not only strings: the `bazel aquery`
# output holds an "actions" list and each compile command holds an
# "arguments" list, so values are typed loosely.
_JSONDict = dict[str, Any]

# The repo root directory: the parent of the directory containing this script.
# `__file__` is made absolute before taking its directory so that a bare,
# relative `__file__` (whose dirname is '') still resolves to the repo root
# instead of '/'.
_REPO_ROOT: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Where to write the compile_commands.json file.
_COMPILE_DB_PATH: str = os.path.join(_REPO_ROOT, 'build/compile_commands.json')

# Flags to filter from the C++ compiler commandline.
_DISALLOWED_ARGS = frozenset(["-fno-canonical-system-headers"])


@dataclasses.dataclass
class CompileCommand:
  """Represents a compilation command with options on a specific C++ file.

  Attributes:
    file: The source file being compiled, exactly as it appears in the
      argument list.
    arguments: The argument list from `bazel aquery` with every entry of
      `_DISALLOWED_ARGS` removed.
  """

  file: str
  arguments: list[str]

  # Suffixes that mark an argument as the C/C++ source file being compiled.
  # Besides ".c", ".cc" and ".cpp" this covers the other C++ source
  # extensions Bazel's cc rules accept (".cxx", ".c++", ".C"). It is left
  # unannotated so that dataclasses does not treat it as a field.
  _SOURCE_SUFFIXES = (".cc", ".cpp", ".cxx", ".c++", ".C", ".c")

  @classmethod
  def from_args_list(cls, args_list: list[str]) -> Optional["CompileCommand"]:
    """Alternative constructor which uses the args_list from `bazel aquery`.

    Collects the arguments of one `bazel aquery` action and the source file
    it compiles. It drops every argument listed in `_DISALLOWED_ARGS`
    (flags that break clang-based tooling such as clangd and clang-tidy).

    Arguments:
      args_list: List of arguments generated by `bazel aquery`.

    Returns:
      The corresponding CompileCommand, or None if no remaining argument ends
      in a C/C++ source suffix. If several do, the last one is taken as the
      source file.
    """
    filtered_args = [arg for arg in args_list if arg not in _DISALLOWED_ARGS]

    # Scan from the end so that the last source-like argument wins.
    cc_file = next(
        (arg for arg in reversed(filtered_args)
         if arg.endswith(cls._SOURCE_SUFFIXES)),
        None,
    )
    if cc_file is None:
      return None
    return cls(cc_file, filtered_args)

  def to_dumpable_json(self, directory: str) -> _JSONDict:
    """Returns this command as one entry of a JSON compilation database.

    Arguments:
      directory: The directory to record as the entry's working directory.

    Returns:
      A dict with the "directory", "file" and "arguments" keys, ready for
      `json.dump`.
    """
    return {
        "directory": directory,
        "file": self.file,
        "arguments": self.arguments,
    }


def extract_cc_compile_commands(
    parsed_aquery_output: _JSONDict,) -> list[CompileCommand]:
  """Gathers C++ compile commands to run from `bazel aquery` JSON output.

  Arguments:
    parsed_aquery_output: Parsed JSON representing the output of `bazel aquery
      --output=jsonproto`.

  Returns:
    The list of CompileCommands for C++ files, in action order. Actions whose
    arguments name no C/C++ source file are skipped. A missing "actions" key
    (or an action without "arguments") is treated as having nothing to
    contribute rather than raising KeyError.
  """
  candidates = (
      CompileCommand.from_args_list(action.get("arguments", []))
      for action in parsed_aquery_output.get("actions", []))
  return [command for command in candidates if command is not None]


def main() -> None:
  """Regenerates `_COMPILE_DB_PATH` from Bazel's C++ compile actions.

  Rebuilds the repo (best effort). Then asks `bazel aquery` for every
  CppCompile action under //... and writes one JSON compilation database
  entry per C/C++ source file found.

  Exits with the module usage text if any command-line arguments are given.
  Exits with an error message if `bazel aquery` does not print valid JSON.
  """
  if len(sys.argv) != 1:
    sys.exit(__doc__)

  os.chdir(_REPO_ROOT)

  print(
      "Rebuilding the repo to ensure that all external repos and "
      "generated files are available locally when generating "
      "compile_commands.json. This may take several minutes...",
      file=sys.stderr)
  # `--keep_going` makes the build best effort. A failure is reported as a
  # warning rather than aborting, because the compile database can still be
  # generated, just possibly referencing files that were not produced.
  build_result = subprocess.run(['bazel', 'build', '--keep_going', '//...'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
  if build_result.returncode != 0:
    print(
        f"Warning: `bazel build` exited with code {build_result.returncode}; "
        "compile_commands.json may reference missing generated files.",
        file=sys.stderr)

  print("Querying bazel for the CppCompile actions...", file=sys.stderr)
  aquery_result = subprocess.run(
      [
          'bazel', 'aquery', '--keep_going', 'mnemonic(CppCompile, //...)',
          '--output=jsonproto'
      ],
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE)
  aquery_output = aquery_result.stdout.decode('utf-8')
  try:
    aquery_json = json.loads(aquery_output)
  except json.JSONDecodeError:
    # stderr was captured above; show Bazel's own diagnostics instead of an
    # opaque JSON parse error.
    sys.stderr.write(aquery_result.stderr.decode('utf-8', errors='replace'))
    sys.exit(f"`bazel aquery` (exit code {aquery_result.returncode}) did not "
             "produce valid JSON output.")

  print(f"Generating {_COMPILE_DB_PATH}...", file=sys.stderr)
  commands = extract_cc_compile_commands(aquery_json)
  # open() does not create parent directories, and build/ may not exist yet.
  os.makedirs(os.path.dirname(_COMPILE_DB_PATH), exist_ok=True)
  with open(_COMPILE_DB_PATH, "w", encoding="utf-8") as f:
    json.dump(
        [
            command.to_dumpable_json(directory=str(_REPO_ROOT))
            for command in commands
        ],
        f,
    )


if __name__ == "__main__":
  # Run only when executed as a script; importing the module just defines
  # the constants and functions above.
  main()
